import numpy as np
import pandas as pd
import re
from scipy.optimize import minimize
from tqdm import tqdm
from joblib import Parallel, delayed
import logging
import time
import matplotlib.pyplot as plt

# Lookup table of every prime below 1000 (168 values, ascending).
def _primes_below(limit):
    """Return all primes smaller than ``limit`` via a sieve of Eratosthenes."""
    is_prime = [True] * limit
    is_prime[0:2] = [False, False]
    for candidate in range(2, int(limit ** 0.5) + 1):
        if is_prime[candidate]:
            # Every smaller multiple was already struck by a smaller prime.
            for multiple in range(candidate * candidate, limit, candidate):
                is_prime[multiple] = False
    return [number for number, flag in enumerate(is_prime) if flag]


PRIMES = _primes_below(1000)

# Golden ratio used by the closed-form Fibonacci expression below.
phi = (1 + np.sqrt(5)) / 2
# Memo of fib_real() results, keyed by the (possibly fractional) argument.
fib_cache = {}


def fib_real(n):
    """Return the Fibonacci number F(n) extended to real-valued ``n``.

    Uses the real-valued Binet form

        F(n) = (phi**n - cos(pi * n) * phi**(-n)) / sqrt(5)

    which equals the ordinary Fibonacci numbers at integer ``n``
    (F(0) = 0, F(1) = F(2) = 1, F(10) = 55).  Both terms must be divided by
    sqrt(5); dividing only the first one shifts every value
    (e.g. it would give F(1) ~= 1.34).

    Args:
        n: Real (hashable) argument.

    Returns:
        The interpolated Fibonacci value, or ``0.0`` when ``n > 100``.
        Results for ``n <= 100`` are memoised in ``fib_cache``.
    """
    if n in fib_cache:
        return fib_cache[n]
    # NOTE(review): arguments above 100 are deliberately mapped to 0.0 (and
    # not cached); kept as-is because callers rely on this cut-off.
    if n > 100:
        return 0.0
    growing = phi**n
    decaying = np.cos(np.pi * n) * phi**(-n)
    result = (growing - decaying) / np.sqrt(5)
    fib_cache[n] = result
    return result

def D(n, beta, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0):
    """Evaluate the model value for index ``n`` and fractional offset ``beta``.

    The value is
    ``sqrt(max(scale * phi * F * base**(n + beta) * P * Omega, 1e-30)) * r**k``
    where ``F = fib_real(n + beta)`` and ``P`` is the entry of ``PRIMES`` at
    ``floor(n + beta)`` taken modulo the table length.  For ``n > 1000`` the
    product is additionally multiplied by ``log(n) / log(1000)``.

    Returns:
        The model value, or ``None`` when the intermediate product is not
        finite or any exception is raised (the exception is logged).
    """
    try:
        exponent = n + beta
        fib_term = fib_real(exponent)
        table_size = len(PRIMES)
        prime_term = PRIMES[int(np.floor(exponent) + table_size) % table_size]
        # base**(n + beta), written through exp/log.
        power_term = np.exp(exponent * np.log(base))  # Logarithmic to avoid overflow
        product = scale * phi * fib_term * power_term * prime_term * Omega
        if n > 1000:
            product *= np.log(n) / np.log(1000)
        if np.isfinite(product):
            # Clamp to a tiny positive floor so the square root is defined.
            return np.sqrt(max(product, 1e-30)) * (r ** k)
        return None
    except Exception as e:
        logging.error(f"D failed: n={n}, beta={beta}, error={e}")
        return None

def invert_D(value, r=1.0, k=1.0, Omega=1.0, base=2, scale=1.0, max_n=500):
    """Grid-search the parameters of ``D`` that best reproduce ``value``.

    The search covers 50 values of ``n``, 5 values of ``beta`` in [0, 1],
    10 log-spaced scale multipliers centred on the magnitude of ``value``,
    and ``r``/``k`` each in {0.5, 1.0}.

    Args:
        value: Target number to approximate.
        r, k: Accepted for signature compatibility; the search always scans
            its own ``r``/``k`` grid, so these are not used.
        Omega, base: Forwarded unchanged to ``D``.
        scale: Base scale; each grid multiplier is applied on top of it.
        max_n: Smallest allowed upper end of the ``n`` range.  The upper end
            grows with ``200 * log10(|value|)`` and is capped at 1000.  With
            the default of 500 this matches the previous fixed range.

    Returns:
        ``(n, beta, scale_multiplier, emergent_uncertainty, r, k)`` for the
        closest candidate, where ``emergent_uncertainty`` is the standard
        deviation of the model values of the 20 closest candidates; or
        ``None`` when ``value`` is not finite, no candidate is valid, or an
        error occurs (logged).
    """
    if not np.isfinite(value):
        # log10/int below would raise on NaN or infinity.
        logging.error(f"invert_D: No valid candidates for value {value}")
        return None

    log_val = np.log10(max(abs(value), 1e-30))
    n_ceiling = min(1000, max(max_n, int(200 * log_val)))
    n_values = np.linspace(0, n_ceiling, 50)
    beta_values = np.linspace(0, 1, 5)
    scale_factors = np.logspace(max(log_val - 2, -10), min(log_val + 2, 10), num=10)

    candidates = []
    try:
        for n in tqdm(n_values, desc=f"invert_D for {value:.2e}", leave=False):
            for beta in beta_values:
                for dynamic_scale in scale_factors:
                    for r_local in (0.5, 1.0):
                        for k_local in (0.5, 1.0):
                            val = D(n, beta, r_local, k_local, Omega, base, scale * dynamic_scale)
                            if val is None:
                                continue
                            # Keep the model value so it need not be recomputed
                            # for the spread estimate below.
                            candidates.append(
                                (abs(val - value), n, beta, dynamic_scale, r_local, k_local, val)
                            )
        if not candidates:
            logging.error(f"invert_D: No valid candidates for value {value}")
            return None

        # Stable sort on the distance only, so ties keep their grid order.
        candidates.sort(key=lambda candidate: candidate[0])
        closest = candidates[:20]
        best = closest[0]
        emergent_uncertainty = np.std([candidate[6] for candidate in closest])
        if not np.isfinite(emergent_uncertainty):
            logging.error(f"invert_D: Non-finite emergent uncertainty for value {value}")
            return None
        return best[1], best[2], best[3], emergent_uncertainty, best[4], best[5]
    except Exception as e:
        logging.error(f"invert_D failed for value {value}: {e}")
        return None

# One number as written in the NIST listing: digits may be grouped with single
# spaces ("6.644 657 3450 e-27", "25 812.807 45..."), truncated exact values
# carry "...", and the exponent may be separated from the mantissa by spaces.
_CODATA_NUMBER = (
    r"-?\d+(?: \d+)*"            # integer part, optionally space-grouped
    r"(?:\.\d*(?: \d+)*)?"       # fractional part, optionally space-grouped
    r"(?:\.\.\.)?"               # "..." before the exponent
    r"(?:\s*[Ee][+-]?\d+)?"      # exponent
    r"(?:\.\.\.)?"               # "..." after the exponent
)
# name <2+ spaces> value <spaces> uncertainty-or-exact [<spaces> unit]
_CODATA_LINE = re.compile(
    r"^\s*(.*?)\s{2,}"
    rf"({_CODATA_NUMBER})\s+"
    rf"({_CODATA_NUMBER}|\(?exact\)?)"
    r"(?:\s+(\S.*?))?\s*$"
)


def _codata_float(text):
    """Convert a listing number to float, dropping digit-group spaces and '...'."""
    return float(text.replace("...", "").replace(" ", ""))


def parse_codata_ascii(filename):
    """Parse a CODATA-style ASCII table of constants into a DataFrame.

    Blank lines and lines starting with ``Quantity`` or ``-`` are skipped, as
    is any line that does not look like
    ``name  value  uncertainty  [unit]``.  Numbers may use space-separated
    digit groups and a trailing ``...``; an uncertainty of ``exact`` or
    ``(exact)`` is stored as ``0.0``; a missing unit (dimensionless
    quantity) is stored as an empty string.

    Args:
        filename: Path of the text file to read.

    Returns:
        DataFrame with columns ``name``, ``value``, ``uncertainty`` and
        ``unit`` (the columns are present even when no line matched).
    """
    constants = []
    with open(filename, "r") as f:
        for line in f:
            if not line.strip() or line.startswith(("Quantity", "-")):
                continue
            m = _CODATA_LINE.match(line)
            if m is None:
                continue
            name, value_str, uncert_str, unit = m.groups()
            try:
                value = _codata_float(value_str)
                uncertainty = 0.0 if "exact" in uncert_str else _codata_float(uncert_str)
            except ValueError as e:
                logging.warning(f"Failed parsing line: {line.strip()} - {e}")
                continue
            constants.append({
                "name": name.strip(),
                "value": value,
                "uncertainty": uncertainty,
                "unit": (unit or "").strip(),
            })
    return pd.DataFrame(constants, columns=["name", "value", "uncertainty", "unit"])

def check_physical_consistency(df_results):
    """Cross-check fitted parameters of selected pairs of constants.

    Each rule names two constants, a deviation measure computed from their
    first matching rows in ``df_results``, and a tolerance.  When the
    deviation exceeds the tolerance both constants are reported.  A rule is
    skipped when either constant is absent from ``df_results``.

    Returns:
        List of ``(name, reason)`` tuples, two per violated rule.
    """
    def _scale_ratio_vs_two_pi(x, y):
        return abs(x['scale'] / y['scale'] - 2 * np.pi)

    def _n_gap_vs_mass_ratio(x, y):
        return abs(x['n'] - y['n'] - np.log10(1836))

    def _scale_vs_root_two(x, y):
        return abs(x['scale'] - y['scale'] / np.sqrt(2))

    def _value_gap(x, y):
        return abs(x['value'] - y['value'])

    rules = (
        ('Planck constant', 'reduced Planck constant',
         _scale_ratio_vs_two_pi, 0.1, 'scale ratio vs. 2π'),
        ('proton mass', 'proton-electron mass ratio',
         _n_gap_vs_mass_ratio, 0.5, 'n difference vs. log(proton-electron ratio)'),
        ('Fermi coupling constant', 'weak mixing angle',
         _scale_vs_root_two, 0.1, 'scale vs. sin²θ_W/√2'),
        ('tau energy equivalent', 'tau mass energy equivalent in MeV',
         _value_gap, 0.01, 'value consistency'),
    )

    flagged = []
    names = df_results['name']
    for first_name, second_name, deviation, tolerance, description in rules:
        try:
            first = df_results.loc[names == first_name].iloc[0]
            second = df_results.loc[names == second_name].iloc[0]
            violated = deviation(first, second) > tolerance
        except IndexError:
            # One of the two constants is missing: nothing to compare.
            continue
        if violated:
            message = f"Physical inconsistency: {description}"
            flagged.extend([(first_name, message), (second_name, message)])
    return flagged

def _fit_single_constant(row, base, Omega):
    """Fit one constant record; return its result dict, or None on failure."""
    try:
        result = invert_D(row['value'], base=base, Omega=Omega)
        if result is None:
            logging.error(f"Failed inversion for {row['name']}: {row['value']}")
            return None
        n, beta, scale, emergent_uncertainty, r_local, k_local = result
        approx = D(n, beta, r_local, k_local, Omega, base, scale)
        if approx is None:
            return None
        error = abs(approx - row['value'])
        rel_error = error / max(abs(row['value']), 1e-30)
        return {
            'name': row['name'], 'value': row['value'], 'unit': row['unit'],
            'n': n, 'beta': beta, 'approx': approx, 'error': error,
            'rel_error': rel_error, 'uncertainty': row['uncertainty'],
            'emergent_uncertainty': emergent_uncertainty, 'r_local': r_local,
            'k_local': k_local, 'scale': scale
        }
    except Exception as e:
        logging.error(f"Error processing {row['name']}: {e}")
        return None


def _flag_rows(df_results, mask, reason):
    """Mark the rows selected by ``mask`` as bad and append ``reason``.

    ``reason`` is either one string for all selected rows or a Series of
    strings indexed like the selected rows.  Nothing happens when ``mask``
    selects no row.
    """
    if not mask.any():
        return
    df_results.loc[mask, 'bad_data'] = True
    df_results.loc[mask, 'bad_data_reason'] = df_results.loc[mask, 'bad_data_reason'] + reason


def symbolic_fit_all_constants(df, base=2, Omega=1.0):
    """Fit every constant in ``df`` and flag suspicious results.

    Each row of ``df`` (columns ``name``, ``value``, ``uncertainty``,
    ``unit``) is inverted with ``invert_D`` in a joblib worker; rows whose
    inversion fails are dropped.  The remaining rows get a ``bad_data`` flag
    and a ``bad_data_reason`` text built from four checks: per-name IQR
    outliers of the uncertainty, relative error above 0.5, uncertainty
    larger than twice the emergent uncertainty, and
    ``check_physical_consistency``.

    Args:
        df: Constants to fit.
        base, Omega: Model parameters forwarded to ``invert_D`` and ``D``.

    Returns:
        DataFrame with one row per successfully fitted constant; empty (no
        columns) when nothing could be fitted.
    """
    logging.info("Starting symbolic fit for all constants...")

    # Parallel processing with tqdm.  ``maxtasksperchild`` is not a documented
    # joblib.Parallel parameter and is rejected by joblib releases whose
    # constructor takes no extra keywords, so it is not passed.
    fitted = Parallel(n_jobs=-1, timeout=15, backend='loky')(
        delayed(_fit_single_constant)(row, base, Omega)
        for row in tqdm(df.to_dict('records'), desc="Fitting constants")
    )
    df_results = pd.DataFrame([entry for entry in fitted if entry is not None])

    if not df_results.empty:
        df_results['bad_data'] = False
        df_results['bad_data_reason'] = ''

        # IQR outlier detection, evaluated separately for each constant name.
        for name in df_results['name'].unique():
            same_name = df_results['name'] == name
            uncertainties = df_results.loc[same_name, 'uncertainty'].dropna()
            if uncertainties.empty:
                continue
            Q1, Q3 = np.percentile(uncertainties, [25, 75])
            IQR = Q3 - Q1
            outliers = uncertainties[(uncertainties < Q1 - 1.5 * IQR) | (uncertainties > Q3 + 1.5 * IQR)]
            if not outliers.empty:
                _flag_rows(df_results, same_name & df_results['uncertainty'].isin(outliers), 'Uncertainty outlier; ')

        # Flag high relative uncertainty and emergent deviation.  The reason
        # texts are only built for selected rows: DataFrame.apply(axis=1) on
        # an empty selection does not return a Series of strings.
        high_rel_error_mask = df_results['rel_error'] > 0.5
        if high_rel_error_mask.any():
            _flag_rows(
                df_results, high_rel_error_mask,
                df_results.loc[high_rel_error_mask, 'rel_error'].map(
                    lambda x: f"High relative uncertainty ({x:.2e} > 0.5); "),
            )

        high_uncertainty_mask = df_results['uncertainty'] > 2 * df_results['emergent_uncertainty']
        if high_uncertainty_mask.any():
            selected = df_results.loc[high_uncertainty_mask, ['uncertainty', 'emergent_uncertainty']]
            reasons = pd.Series(
                [f"Uncertainty deviates from emergent ({measured:.2e} vs. {emergent:.2e}); "
                 for measured, emergent in zip(selected['uncertainty'], selected['emergent_uncertainty'])],
                index=selected.index,
            )
            _flag_rows(df_results, high_uncertainty_mask, reasons)

        # Physical consistency checks
        for name, reason in check_physical_consistency(df_results):
            _flag_rows(df_results, df_results['name'] == name, reason + '; ')

    logging.info("Symbolic fit completed.")
    return df_results

def total_error(params, df_subset, constants=None):
    """Objective for the optimiser: mean absolute fit error over ``df_subset``.

    Args:
        params: Sequence ``(base, Omega)``.
        df_subset: Constants to fit with ``symbolic_fit_all_constants``.
        constants: Unused; kept (now optional) so existing call sites that
            pass the full table keep working.

    Returns:
        Mean of the ``error`` column of the fit, or ``inf`` when the
        parameters are unusable, nothing could be fitted, or the mean is
        not finite.
    """
    base, Omega = params
    # log(base) is undefined for base <= 0; skip the expensive fit that
    # could not produce a meaningful result anyway.
    if not np.isfinite(base) or base <= 0 or not np.isfinite(Omega):
        return np.inf
    df_results = symbolic_fit_all_constants(df_subset, base, Omega)
    if df_results.empty:
        return np.inf
    error = df_results['error'].mean()
    return error if np.isfinite(error) else np.inf

def main():
    """Parse ``allascii.txt``, optimise ``(base, Omega)`` and save the full fit.

    The optimisation runs Nelder-Mead (at most 20 iterations) on a fixed
    random sample of up to 10 constants; the final fit over all constants is
    written to ``symbolic_fit_results_emergent_fixed.txt`` as CSV.  Progress
    is reported through ``logging``.
    """
    # Without a configured handler the INFO messages below are discarded.
    # basicConfig is a no-op when the root logger is already configured.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    start_time = time.time()
    df = parse_codata_ascii("allascii.txt")
    logging.info(f"Parsed {len(df)} constants")
    if df.empty:
        # Sampling and optimising over nothing cannot produce a result.
        logging.error("No constants parsed from allascii.txt; nothing to fit")
        return

    # Optimize base and Omega
    subset_df = df.sample(n=min(10, len(df)), random_state=42)
    result = minimize(total_error, [2, 1.0], args=(subset_df, df), method='Nelder-Mead', options={'maxiter': 20})
    base, Omega = result.x
    logging.info(f"Optimized parameters: base={base}, Omega={Omega}")

    # Run final fit
    df_results = symbolic_fit_all_constants(df, base, Omega)
    if not df_results.empty:
        df_results.to_csv("symbolic_fit_results_emergent_fixed.txt", index=False)
        logging.info("Saved results to symbolic_fit_results_emergent_fixed.txt")
    else:
        logging.error("No results to save")

    logging.info(f"Total runtime: {time.time() - start_time:.2f} seconds")

if __name__ == "__main__":
    # Script entry point: parse the table, optimise the two model parameters
    # that total_error() accepts, fit every constant, then report and plot.
    print("Parsing CODATA constants from allascii.txt...")
    codata_df = parse_codata_ascii("allascii.txt")
    print(f"Parsed {len(codata_df)} constants.")
    if codata_df.empty:
        raise SystemExit("No constants could be parsed from allascii.txt.")

    # Use a larger subset for optimization
    subset_df = codata_df.head(50)
    # total_error() unpacks exactly (base, Omega), in that order.
    init_params = [2.0, 1.0]
    bounds = [(1.5, 10), (1e-5, 10)]

    print("Optimizing symbolic model parameters...")
    res = minimize(total_error, init_params, args=(subset_df, codata_df), bounds=bounds, method='L-BFGS-B', options={'maxiter': 100})
    base_opt, Omega_opt = res.x
    print(f"Optimization complete. Found parameters:\nbase = {base_opt:.6f}, Omega = {Omega_opt:.6f}")

    print("Fitting symbolic dimensions to all constants...")
    fitted_df = symbolic_fit_all_constants(codata_df, base=base_opt, Omega=Omega_opt)
    if fitted_df.empty:
        raise SystemExit("Symbolic fit produced no results.")
    fitted_df_sorted = fitted_df.sort_values("error")

    report_columns = ['name', 'value', 'unit', 'n', 'beta', 'approx', 'error', 'uncertainty', 'scale', 'bad_data', 'bad_data_reason']

    print("\nTop 20 best symbolic fits:")
    print(fitted_df_sorted.head(20)[report_columns].to_string(index=False))

    print("\nTop 20 worst symbolic fits:")
    print(fitted_df_sorted.tail(20)[report_columns].to_string(index=False))

    print("\nPotentially bad data constants summary:")
    bad_data_df = fitted_df.loc[
        fitted_df['bad_data'].astype(bool),
        ['name', 'value', 'error', 'rel_error', 'uncertainty', 'bad_data_reason'],
    ]
    print(bad_data_df.to_string(index=False))

    fitted_df_sorted.to_csv("symbolic_fit_results.txt", sep="\t", index=False)

    plt.figure(figsize=(10, 5))
    plt.hist(fitted_df_sorted['error'], bins=50, color='skyblue', edgecolor='black')
    plt.title('Histogram of Absolute Errors in Symbolic Fit')
    plt.xlabel('Absolute Error')
    plt.ylabel('Count')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.scatter(fitted_df_sorted['n'], fitted_df_sorted['error'], alpha=0.5, s=15, c='orange', edgecolors='black')
    plt.title('Absolute Error vs Symbolic Dimension n')
    plt.xlabel('n')
    plt.ylabel('Absolute Error')
    plt.grid(True)
    plt.tight_layout()
    plt.show()